use crate::UnsafeSharedRef;
use crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par};

use alloc::{vec, vec::Vec};
use burn_tensor::ElementConversion;
use burn_tensor::Shape;
use ndarray::{IxDyn, s};

/// Multiplies two arrays as batches of matrices, broadcasting the batch dimensions.
///
/// The last two axes of `lhs` (`m x k`) and `rhs` (`k x n`) are the matrix axes; every leading
/// axis is a batch axis. Batch axes are broadcast against each other as computed by
/// `output_shape`, and the result carries the broadcast batch axes followed by `m x n`.
///
/// # Panics
/// Panics whenever `output_shape` rejects the pair of shapes.
pub(crate) fn matmul<E: NdArrayElement>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
) -> SharedArray<E> {
    let shape_lhs = lhs.shape();
    let shape_rhs = rhs.shape();

    // Validate the shapes before indexing into them, so that a rank below 2 is reported by
    // `output_shape` instead of underflowing `ndims - 2` below.
    let (out_shape, strides_lhs, strides_rhs, strides_out) = output_shape(shape_lhs, shape_rhs);

    let ndims = shape_lhs.num_dims();
    let m = shape_lhs[ndims - 2]; // # of left rows
    let k = shape_rhs[ndims - 2]; // # of left cols and right rows
    let n = shape_rhs[ndims - 1]; // # of right cols

    // Count the batches from the leading (non-matrix) dimensions. Dividing the element count by
    // the matrix size instead would divide by zero whenever a matrix dimension is 0.
    let num_l_batches: usize = shape_lhs[..ndims - 2].iter().product();
    let num_r_batches: usize = shape_rhs[..ndims - 2].iter().product();

    // An output matrix of size 0 means the output holds no elements at all, so there is no
    // batch to compute in that case.
    let out_mat_size = m * n;
    let num_out_batches = out_shape
        .num_elements()
        .checked_div(out_mat_size)
        .unwrap_or(0);

    // Collapse all batch axes into a single leading axis on both operands.
    let lhs_array = NdArrayOps::reshape(lhs, Shape::new([num_l_batches, m, k]));
    let rhs_array = NdArrayOps::reshape(rhs, Shape::new([num_r_batches, k, n]));

    // out = alpha * lhs * rhs + beta * out, i.e. a plain product that overwrites `out`.
    let alpha: E = 1.0.elem();
    let beta: E = 0.0.elem();

    let out = run_par!(|| {
        let mut out_array = ndarray::Array3::<E>::zeros((num_out_batches, m, n));
        let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array);

        iter_range_par!(0, num_out_batches).for_each(|out_batch| {
            // Here, we:
            //   1. Un-flatten the output batch into a component-based batch index.
            //   2. Use the strides for left and right batch indices to convert it to a flattened
            //      batch for left and right.
            let out_index = strides_out.unflatten(out_batch);
            let l_batch = strides_lhs.flatten(&out_index);
            let r_batch = strides_rhs.flatten(&out_index);

            let lhs_slice = lhs_array.slice(s!(l_batch, .., ..));
            let rhs_slice = rhs_array.slice(s!(r_batch, .., ..));

            // SAFETY: each invocation takes a mutable view of only the `out_batch`-th matrix of
            // `out_array`. The views never overlap provided `iter_range_par!` hands every index
            // in `0..num_out_batches` to exactly one invocation.
            // NOTE(review): confirm that guarantee against the macro definition.
            unsafe {
                let mut out_slice = unsafe_shared_out_array
                    .get()
                    .slice_mut(s!(out_batch, .., ..));

                ndarray::linalg::general_mat_mul(
                    alpha,
                    &lhs_slice,
                    &rhs_slice,
                    beta,
                    &mut out_slice,
                )
            }
        });

        out_array.into_shared().into_dyn()
    });

    // Restore the broadcast batch axes on the result.
    NdArrayOps::reshape(out, out_shape)
}

/// A list of per-axis strides used to convert between a flat (linear) index and a
/// multi-component index.
///
/// A stride of 0 makes the corresponding index component contribute nothing when flattening.
#[derive(Debug, PartialEq)]
struct Strides {
    strides: Vec<usize>,
}
impl Strides {
    /// Wraps the given per-axis strides, keeping their order.
    fn new(strides: Vec<usize>) -> Self {
        Strides { strides }
    }

    /// Splits `linear_index` into one component per stride.
    ///
    /// Strides are consumed in stored order: each component is the quotient of the remaining
    /// index by the stride, and the remainder is carried to the next stride. For strides such
    /// as `[12, 4, 1]` this yields the usual row-major coordinates.
    ///
    /// # Panics
    /// Panics if any stride is 0 (division by zero).
    fn unflatten(&self, linear_index: usize) -> Vec<usize> {
        let mut coord = Vec::with_capacity(self.strides.len());
        let mut rem = linear_index;
        for &stride in &self.strides {
            coord.push(rem / stride);
            rem %= stride;
        }
        coord
    }

    /// Combines the index components into a flat index: the sum of `stride * component`.
    ///
    /// # Panics
    /// Panics if `index` does not have exactly one component per stride.
    fn flatten(&self, index: &[usize]) -> usize {
        assert_eq!(self.strides.len(), index.len());
        self.strides
            .iter()
            .zip(index)
            .map(|(stride, component)| stride * component)
            .sum()
    }
}

/// Compute the (broadcasted) output shape of matrix multiplication, along with strides for
/// the non-matrix dimensions of all arrays.
///
/// The returned tuple is `(output shape, left strides, right strides, output strides)`. Each
/// set of strides has one entry per non-matrix (batch) dimension, in the same order as the
/// dimensions. A dimension that is broadcast (size 1 against a different size) gets stride 0
/// on the broadcast side.
///
/// # Arguments
/// * `lsh`: Shape of the first (left-hand) matrix multiplication argument.
/// * `rsh`: Shape of the second (right-hand) matrix multiplication argument.
///
/// # Panics
/// * If `lsh` has fewer than 2 dimensions.
/// * If `lsh` and `rsh` do not have the same number of dimensions.
/// * If the matrix multiplication dimensions (last 2) are incompatible.
/// * If any other dimension is not the same for both tensors, or equal to 1. (Any dimension where
///   one dim is equal to 1 is broadcast.)
fn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) {
    let ndims = lsh.num_dims();
    if ndims < 2 {
        panic!("Matrix multiplication requires an array with at least 2 dimensions.");
    }
    // Every index below is derived from the left rank. Without this check a shorter `rsh` fails
    // with an opaque out-of-bounds panic and a longer one is silently misread.
    if rsh.num_dims() != ndims {
        panic!("Matrix multiplication requires arrays with the same number of dimensions.");
    }

    // Fetch matrix dimensions and check compatibility.
    let l_rows = lsh[ndims - 2];
    let l_cols = lsh[ndims - 1];
    let r_rows = rsh[ndims - 2];
    let r_cols = rsh[ndims - 1];
    if l_cols != r_rows {
        panic!("Dimensions are incompatible for matrix multiplication.");
    }
    // Set matrix dimensions of the output shape.
    let mut osh = vec![0; ndims];
    osh[ndims - 2] = l_rows;
    osh[ndims - 1] = r_cols;

    // Set other array dimensions, broadcasting as necessary.
    // Strides are accumulated from the innermost batch dimension outwards, so they are pushed
    // in reverse order and flipped once the loop is done.
    let num_batch_dims = ndims - 2;
    let mut cur_l_stride: usize = 1;
    let mut cur_r_stride: usize = 1;
    let mut cur_o_stride: usize = 1;
    let mut l_strides = Vec::with_capacity(num_batch_dims);
    let mut r_strides = Vec::with_capacity(num_batch_dims);
    let mut o_strides = Vec::with_capacity(num_batch_dims);
    for i in (0..num_batch_dims).rev() {
        let l_dim = lsh[i];
        let r_dim = rsh[i];

        // Compatible dimensions are:
        //   1. Both dimensions are equal.
        //   2. One of the dimensions is equal to 1.
        let (o_dim, l_stride, r_stride) = if l_dim == r_dim {
            (l_dim, cur_l_stride, cur_r_stride) // both dimensions are equal
        } else if l_dim == 1 {
            (r_dim, 0, cur_r_stride) // broadcast the left
        } else if r_dim == 1 {
            (l_dim, cur_l_stride, 0) // broadcast the right
        } else {
            panic!("Dimensions differ and cannot be broadcasted.");
        };

        osh[i] = o_dim;
        l_strides.push(l_stride);
        r_strides.push(r_stride);
        o_strides.push(cur_o_stride);

        cur_l_stride *= l_dim;
        cur_r_stride *= r_dim;
        cur_o_stride *= o_dim;
    }
    l_strides.reverse();
    r_strides.reverse();
    o_strides.reverse();

    (
        Shape::from(osh),
        Strides::new(l_strides),
        Strides::new(r_strides),
        Strides::new(o_strides),
    )
}

/// Computes the cross product of the 3-component vectors stored along axis `dim`.
///
/// All axes other than `dim` are broadcast between `lhs` and `rhs` (equal sizes, or one of
/// them equal to 1). The result has the broadcast shape, with the vector axis back at
/// position `dim`.
///
/// The size of axis `dim` is taken from `lhs` without being checked here; callers are expected
/// to have verified that it is 3.
///
/// # Panics
/// Panics if a non-`dim` axis differs between the operands and neither size is 1.
pub(crate) fn cross<E: NdArrayElement>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
    dim: usize,
) -> SharedArray<E> {
    let lhs_dims = lhs.shape();
    let rhs_dims = rhs.shape();
    let rank = lhs_dims.num_dims();

    // Work out the common shape: the vector axis is copied from the left operand, every other
    // axis follows the usual "equal or one of them is 1" broadcasting rule.
    let target_dims: Vec<usize> = (0..rank)
        .map(|axis| {
            let l = lhs_dims[axis];
            if axis == dim {
                return l;
            }
            match (l, rhs_dims[axis]) {
                (l, r) if l == r => l,
                (1, r) => r,
                (l, 1) => l,
                _ => panic!("Tensors are not broadcastable along dimension {}", axis),
            }
        })
        .collect();

    // Only expand an operand whose shape is not already the common shape.
    let lhs_full = if lhs_dims == target_dims.as_slice() {
        lhs
    } else {
        NdArrayOps::expand(lhs, Shape::from(target_dims.clone()))
    };
    let rhs_full = if rhs_dims == target_dims.as_slice() {
        rhs
    } else {
        NdArrayOps::expand(rhs, Shape::from(target_dims.clone()))
    };

    // Axis order that keeps every axis in place except `dim`, which goes last.
    let mut axes_order = (0..rank).collect::<Vec<_>>();
    axes_order.remove(dim);
    axes_order.push(dim);

    let lhs_vec_last = NdArrayOps::permute(lhs_full, &axes_order);
    let rhs_vec_last = NdArrayOps::permute(rhs_full, &axes_order);

    // Flatten everything but the vector axis: one row per vector, three columns.
    let num_vectors = lhs_vec_last.shape().num_elements() / 3;
    let lhs_rows = NdArrayOps::reshape(lhs_vec_last, Shape::new([num_vectors, 3]));
    let rhs_rows = NdArrayOps::reshape(rhs_vec_last, Shape::new([num_vectors, 3]));

    // Row-by-row cross product: a x b = (a2*b3 - a3*b2, a3*b1 - a1*b3, a1*b2 - a2*b1).
    let mut product = ndarray::ArrayD::<E>::zeros(IxDyn(&[num_vectors, 3]));
    for row in 0..num_vectors {
        let at = |col: usize| IxDyn(&[row, col]);

        let (a1, a2, a3) = (lhs_rows[at(0)], lhs_rows[at(1)], lhs_rows[at(2)]);
        let (b1, b2, b3) = (rhs_rows[at(0)], rhs_rows[at(1)], rhs_rows[at(2)]);

        product[at(0)] = a2.mul(b3).sub(a3.mul(b2));
        product[at(1)] = a3.mul(b1).sub(a1.mul(b3));
        product[at(2)] = a1.mul(b2).sub(a2.mul(b1));
    }

    // Give the rows their broadcast batch axes back; the vector axis is still last here.
    let mut vec_last_dims = target_dims;
    vec_last_dims.remove(dim);
    vec_last_dims.push(3);
    let product_vec_last =
        NdArrayOps::reshape(product.into_shared(), Shape::from(vec_last_dims));

    // Invert the earlier axis reordering so the vector axis returns to position `dim`.
    let mut inverse_order = vec![0; rank];
    for (position, &axis) in axes_order.iter().enumerate() {
        inverse_order[axis] = position;
    }
    NdArrayOps::permute(product_vec_last, &inverse_order)
}

#[cfg(test)]
mod tests {
    use super::*;

    impl Strides {
        /// Strides of an array that has no batch dimensions.
        fn empty() -> Self {
            Strides::new(Vec::new())
        }
    }

    /// Asserts that `output_shape(lsh, rsh)` returns the output dimensions `dims` together with
    /// the expected `[left, right, output]` batch strides.
    fn check(lsh: &[usize], rsh: &[usize], dims: &[usize], strides: [Strides; 3]) {
        let [left, right, out] = strides;
        assert_eq!(
            output_shape(lsh, rsh),
            (Shape::from(dims.to_vec()), left, right, out)
        );
    }

    /// Shorthand for building the expected strides of one array.
    fn st(strides: &[usize]) -> Strides {
        Strides::new(strides.to_vec())
    }

    #[test]
    fn test_output_shape() {
        // plain matrix multiply
        check(
            &[5, 3],
            &[3, 7],
            &[5, 7],
            [Strides::empty(), Strides::empty(), Strides::empty()],
        );
        // matrix multiply with one extra stack dimension
        check(
            &[4, 5, 3],
            &[4, 3, 7],
            &[4, 5, 7],
            [st(&[1]), st(&[1]), st(&[1])],
        );
        // rank 3, broadcast left
        check(
            &[1, 5, 3],
            &[4, 3, 7],
            &[4, 5, 7],
            [st(&[0]), st(&[1]), st(&[1])],
        );
        // rank 3, broadcast right
        check(
            &[4, 5, 3],
            &[1, 3, 7],
            &[4, 5, 7],
            [st(&[1]), st(&[0]), st(&[1])],
        );
        // rank 4, multi broadcast
        check(
            &[1, 4, 5, 3],
            &[8, 1, 3, 7],
            &[8, 4, 5, 7],
            [st(&[0, 1]), st(&[1, 0]), st(&[4, 1])],
        );
        // rank 5, multi-broadcast
        check(
            &[1, 3, 4, 5, 3],
            &[8, 3, 1, 3, 7],
            &[8, 3, 4, 5, 7],
            [st(&[0, 4, 1]), st(&[3, 1, 0]), st(&[12, 4, 1])],
        );
    }

    #[test]
    #[should_panic(
        expected = "Matrix multiplication requires an array with at least 2 dimensions."
    )]
    fn test_output_shape_too_small() {
        // a rank-1 operand has no matrix axes
        output_shape(&[4], &[4]);
    }

    #[test]
    #[should_panic(expected = "Dimensions are incompatible for matrix multiplication.")]
    fn test_output_shape_bad_matrix_dims() {
        // left columns (3) do not match right rows (4)
        output_shape(&[5, 3], &[4, 7]);
    }

    #[test]
    #[should_panic(expected = "Dimensions differ and cannot be broadcasted.")]
    fn test_output_shape_non_broadcast() {
        // batch sizes 4 and 2 differ and neither is 1
        output_shape(&[4, 5, 3], &[2, 3, 7]);
    }
}
